{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "prompt-peoples",
   "metadata": {},
   "source": [
    "## Autogeneration of license plates images for training LPRNet\n",
    "\n",
    "This notebook provide a skeleton to generate random license plate images for training LPRNet for license plate recognition. Auto-generation of synthetic training data constitutes an easy strategy of generating large datasets in a cheap manner, alleviating the burden of collecting and annotating real data. In contrast to many other computer vision tasks, training a network for text recognition can rely on synthetic data alone and achieve a good recognition performance on real-life data. \n",
    "\n",
    "In order to use thiss notebook, clean license plate templates (with no text) and predefined text fonts (`.ttf`) for this purpose should be prepared in advance and placed under the following paths:\n",
    " - `./plates/` - path to clean license plates:\n",
    " - `./fonts` - path to clean license plates\n",
    "\n",
    "<br>\n",
    "There are several knobs that can be tempered with to control the characteristics of the generated license plates, such as:\n",
    " \n",
    " - license plate text length\n",
    " - coordinates within the plate to write the text at\n",
    " - spereations between the license plate text characters\n",
    " - color difference between each character in the license plate\n",
    "\n",
    "**_Notes_**:\n",
    " - The current flow is designed to generate license plates for training the model to recognize Israeli license plates. \n",
    " - Hailo's LPRNet was trained on an autogenearted dataset of 4 million images. Training the model on smaller datasets resulted in a significant accuracy drop\n",
    " - In addition to the large amount of training data, we have also used data augmentations in the training phase, which we found to greatly aid both recongnition performance and training convergence."
   ]
  },
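  {
   "cell_type": "markdown",
   "id": "sketch-asset-check",
   "metadata": {},
   "source": [
    "Before running the cells below, it can help to confirm that both asset folders are populated. The following is a minimal sketch and not part of the original flow: `list_assets` is a hypothetical helper, and the extension filters (`.png`/`.jpg` for plates, `.ttf` for fonts) are assumptions about how the assets are stored.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "def list_assets(folder, extensions):\n",
    "    '''Return sorted paths of the files in `folder` whose name ends with one of `extensions`.'''\n",
    "    # hypothetical helper, not part of the notebook's original code\n",
    "    if not os.path.isdir(folder):\n",
    "        raise FileNotFoundError(f'missing asset folder: {folder}')\n",
    "    paths = sorted(\n",
    "        os.path.join(folder, name)\n",
    "        for name in os.listdir(folder)\n",
    "        if name.lower().endswith(extensions)  # case-insensitive extension match\n",
    "    )\n",
    "    if not paths:\n",
    "        raise ValueError(f'no files with extensions {extensions} in {folder}')\n",
    "    return paths\n",
    "\n",
    "# usage (assumes the folders described above exist):\n",
    "# plate_paths = list_assets('plates', ('.png', '.jpg'))\n",
    "# font_paths = list_assets('fonts', ('.ttf',))\n",
    "```\n",
    "\n",
    "If either call raises, the corresponding folder is missing or empty, and the generation cells would fail later with a less obvious error."
   ]
  },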
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "interesting-bones",
   "metadata": {},
   "outputs": [],
   "source": [
    "#imports\n",
    "import random\n",
    "import os\n",
    "import numpy as np\n",
    "from IPython.display import display\n",
    "from IPython.display import Image as Display\n",
    "from PIL import ImageDraw, ImageFont, Image\n",
    "import imgaug.augmenters as iaa\n",
    "import cv2\n",
    "import time\n",
    "import multiprocessing\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "threaded-round",
   "metadata": {},
   "outputs": [],
   "source": [
    "# global configurations\n",
    "clean_plates = [os.path.join('plates', x) for x in os.listdir('plates')]\n",
    "fonts = [os.path.join('fonts', x) for x in os.listdir('fonts')]\n",
    "\n",
    "font_size = [40, 60]               # font size min,max\n",
    "plate_length = [7, 8]              # plate length min,max\n",
    "basic_color_diff = 100             # difference in color form black/white\n",
    "right_margin = [35, 45]            # margin to start the text from\n",
    "top_margin = [8, 12]               # margin to start the text from\n",
    "margin_between_letters = [10, 25]  # margin between two letters\n",
    "\n",
    "\n",
    "# per letter configurations\n",
    "difference_in_size = 5    # font size diff between letters\n",
    "difference_in_color = 50  # color diff between letters\n",
    "difference_in_loc = 5     # diff in location between letters"
   ]
  },
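  {
   "cell_type": "markdown",
   "id": "sketch-per-letter-jitter",
   "metadata": {},
   "source": [
    "The per-letter settings act as jitter bounds: for each character, `write_text` (defined further down) adds a random signed offset, smaller in magnitude than the configured value, to the base font size, color, and position. A minimal sketch of that sampling follows; `jitter` is a hypothetical helper name introduced only for illustration.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def jitter(base, max_diff):\n",
    "    '''Return base plus a random signed offset with magnitude in [0, max_diff - 1].'''\n",
    "    # hypothetical helper mirroring the per-character expression used in write_text\n",
    "    return base + random.choice([-1, 1]) * random.randrange(max_diff)\n",
    "\n",
    "# with difference_in_size = 5, a base font size of 50 gives sizes between 46 and 54\n",
    "sizes = [jitter(50, 5) for _ in range(1000)]\n",
    "print(min(sizes), max(sizes))\n",
    "```\n",
    "\n",
    "Raising a `difference_in_*` value widens the spread between neighbouring characters, while a value of 1 disables the jitter for that property."
   ]
  },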
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "light-block",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# define some util functions\n",
    "\n",
    "# for each clean plate, define the (relative) box coords to be distorted\n",
    "box_coord_to_distort = {0: [0.115, 0.97, 0.13, 0.90],\n",
    "                        1: [0.115, 0.945, 0.18, 0.85],\n",
    "                        2: [0.105, 0.98, 0.08, 0.92],\n",
    "                        3: [0.13, 0.96, 0.18, 0.85],\n",
    "                        4: [0.04, 0.96, 0.2, 0.80],\n",
    "                        5: [0.13, 0.96, 0.15, 0.77],\n",
    "                        }\n",
    "\n",
    "def distort_background(img, xp0, xp1, yp0, yp1,\n",
    "                       shape=(1000, 221),\n",
    "                       mean_c = (175, 140, 55),\n",
    "                       std_c = (10, 15, 40),\n",
    "                       wfact=30, hfact=60\n",
    "                       ):\n",
    "    '''\n",
    "    function gets a clean license plate image and relative box coordinates within it and distorts this area\n",
    "    Args:\n",
    "        img: clean license plate image\n",
    "        xp0, xp1, yp0, yp1: relative box coordinates defining the crop to be distorted\n",
    "        shape: resize the image to this value. box coordinates should be compatible with this value\n",
    "        mean_c, std_c: define the min/max color value range to be sampled from and added\n",
    "                       to each RGB channel in the cropped area\n",
    "                       such that the min value is mean_c-std_c\n",
    "                       and the max value is mean_c-std_c\n",
    "        wfact, hfact: used to also downscale and upscale the cropped area dimensions to \"pixelize\" the distortion\n",
    "    '''\n",
    "    \n",
    "    # resize image\n",
    "    img = np.array(Image.fromarray(np.squeeze(img)).resize((shape)), dtype=np.uint8)\n",
    "    \n",
    "    # get absolute box coords in image to crop out and distort               \n",
    "    xmin, xmax, ymin, ymax =\\\n",
    "        int(xp0*img.shape[1]), int(xp1*img.shape[1]),\\\n",
    "        int(yp0*img.shape[0]), int(yp1*img.shape[0])\n",
    "    \n",
    "    cut = img[ymin:ymax,xmin:xmax, :]\n",
    "    new_width = int(cut.shape[0]/wfact)\n",
    "    new_height = int(cut.shape[1]/hfact)\n",
    "\n",
    "    # downsample cropped area to distort\n",
    "    cut_resized = np.array(Image.fromarray(np.squeeze(cut)).resize((new_width, new_height)), dtype=np.uint8)\n",
    "    h,w,c = cut_resized.shape\n",
    "    distort = np.zeros((h,w,c), dtype=np.uint8)\n",
    "\n",
    "    # add color distortion\n",
    "    for i in range(distort.shape[-1]):\n",
    "        distort[:,:,i] += np.random.randint(mean_c[i]-std_c[i], mean_c[i]+std_c[i], size=(h, w), dtype=np.uint8)\n",
    "    \n",
    "    new_img = img.copy().transpose((2,0,1))\n",
    "\n",
    "    # upsample distorted area back to original shape, and plug into image\n",
    "    new_img[:,ymin:ymax,xmin:xmax] = np.array(Image.fromarray(np.squeeze(distort)).resize(cut.shape[:2],Image.NEAREST), dtype=np.uint8).transpose()\n",
    "    new_img = new_img.transpose((1,2,0))\n",
    "    \n",
    "    return Image.fromarray(np.squeeze(new_img))\n",
    "\n",
    "\n",
    "def get_random_string():\n",
    "    '''\n",
    "    get the random license plate number\n",
    "    '''\n",
    "    def get_random_string_v1():\n",
    "        lp_len = random.choice(plate_length)\n",
    "        chars = [str(x) for x in range(10)]\n",
    "        sep = ['-',' ']\n",
    "        groups = []\n",
    "        if lp_len == 7:\n",
    "            groups.extend(random.choices(chars, k=2))\n",
    "            groups.extend(random.choices(sep,weights=[85,15], k=1))\n",
    "            groups.extend(random.choices(chars, k=3))\n",
    "            groups.extend(random.choices(sep,weights=[85,15], k=1))\n",
    "            groups.extend(random.choices(chars, k=2))\n",
    "        elif lp_len == 8:\n",
    "            groups.extend(random.choices(chars, k=3))\n",
    "            groups.extend(random.choices(sep,weights=[85,15], k=1))\n",
    "            groups.extend(random.choices(chars, k=2))\n",
    "            groups.extend(random.choices(sep,weights=[85,15], k=1))\n",
    "            groups.extend(random.choices(chars, k=3))\n",
    "        return groups\n",
    "\n",
    "    def get_random_string_v2():\n",
    "        lp = random.choice(plate_length)\n",
    "        num = [str(x) for x in range(10)]\n",
    "        t = []\n",
    "        for i in range(lp+2):\n",
    "            if (((i == 3 or i == 6) and lp == 8) or\n",
    "                ((i == 2 or i == 6) and lp == 7)):\n",
    "                t.append(random.choice([' ', '-', '.']))\n",
    "            else:\n",
    "                t.append(random.choice(num))\n",
    "        return t\n",
    "    \n",
    "    def get_random_string_v3():\n",
    "        lp_len = random.choice(plate_length)\n",
    "        chars = [str(x) for x in range(10)] + [' ', '-']\n",
    "        p = [0.9 / 11] * 11 + [0.1]\n",
    "        return np.random.choice(chars, lp_len, p=p)\n",
    "    \n",
    "    random_string_funcs = [get_random_string_v1,\n",
    "                           get_random_string_v2,\n",
    "                           get_random_string_v3]\n",
    "\n",
    "    return np.random.choice(random_string_funcs)()\n",
    "\n",
    "def write_text(img, text_list):\n",
    "    '''\n",
    "    Randomly add the generated plate number (text_list) into the clean plate (img)\n",
    "    '''\n",
    "    img_pil = Image.fromarray(img)\n",
    "    font_sz = int(random.choice(np.linspace(font_size[0], font_size[1], 1 + font_size[1] - font_size[0])))\n",
    "    basic_color = [random.choice(np.linspace(0, basic_color_diff, basic_color_diff+1)) for i in range(3)]\n",
    "    fontc = random.choice(fonts)\n",
    "    draw = ImageDraw.Draw(img_pil)\n",
    "    w, h = img_pil.size\n",
    "    rmargin = int(random.choice(np.linspace(right_margin[0], right_margin[1])))\n",
    "    tmargin = int(random.choice(np.linspace(top_margin[0], top_margin[1])))\n",
    "    f = ImageFont.truetype(fontc, font_sz)\n",
    "    while True:\n",
    "        tw, th = draw.textsize(''.join(text_list), font=f)\n",
    "        if tw < w * 0.8:\n",
    "            break\n",
    "        font_sz -= 1\n",
    "        f = ImageFont.truetype(fontc, font_sz)\n",
    "\n",
    "    margin = tw / len(text_list)*1.05\n",
    "    for i, t in enumerate(text_list):\n",
    "        f = ImageFont.truetype(fontc, font_sz + random.choice([-1, 1]) * random.choice([x for x in range(difference_in_size)]))\n",
    "        color = [int(x + random.choice([-1, 1]) * random.choice([x for x in range(difference_in_color)])) for x in basic_color]\n",
    "        draw.text((rmargin + i * margin + random.choice([-1, 1]) * random.choice([x for x in range(difference_in_loc)]),\n",
    "                   tmargin + random.choice([-1, 1]) * random.choice([x for x in range(difference_in_loc)])),\n",
    "                  t, fill=tuple(color), font=f)\n",
    "    return np.array(img_pil, np.uint8)\n",
    "\n",
    "def gen_random_plate(shape=(1000, 221), final_shape=(300, 75)):\n",
    "    '''\n",
    "    get an autogenerated license plate image and its random plate number\n",
    "    '''\n",
    "    \n",
    "    # get random plate number\n",
    "    lp = get_random_string()\n",
    "    lp_text = ''.join(lp).replace('.','').replace(' ','').replace('-','').replace('~','')\n",
    "\n",
    "    # get clean plate\n",
    "    clean_plate = random.choice(clean_plates)\n",
    "    img = Image.open(clean_plate)\n",
    "    \n",
    "    if random.random() >= 0.5:\n",
    "        # distort a certain area of the plate's background\n",
    "        img = distort_background(img, *box_coord_to_distort[clean_plates.index(clean_plate)])\n",
    "    \n",
    "    img = Image.fromarray(np.squeeze(img)).resize(final_shape)\n",
    "\n",
    "    # write the text\n",
    "    img_pil = write_text(np.array(img), lp)\n",
    "    img = np.array(img_pil)\n",
    "    \n",
    "    if random.random() >= 0.5:\n",
    "        # convert to BW\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)\n",
    "\n",
    "    return img, lp_text\n"
   ]
  },
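  {
   "cell_type": "markdown",
   "id": "sketch-pixelized-noise",
   "metadata": {},
   "source": [
    "The pixelize step in `distort_background` boils down to drawing random colors on a coarse grid and enlarging that grid with nearest-neighbour interpolation, so that each random color covers a block of pixels. The sketch below shows the idea with NumPy only; `pixelized_noise` and its `block` parameter are hypothetical names, and `np.repeat` stands in for the PIL resize used above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pixelized_noise(height, width, block=8,\n",
    "                    mean_c=(175, 140, 55), std_c=(10, 15, 40)):\n",
    "    '''Return a (height, width, 3) uint8 array of blocky random colors.'''\n",
    "    # hypothetical helper; np.repeat replaces the nearest-neighbour PIL resize\n",
    "    gh = -(-height // block)  # ceil division: number of coarse grid rows\n",
    "    gw = -(-width // block)   # number of coarse grid columns\n",
    "    coarse = np.stack(\n",
    "        [np.random.randint(m - s, m + s, size=(gh, gw)) for m, s in zip(mean_c, std_c)],\n",
    "        axis=-1,\n",
    "    ).astype(np.uint8)\n",
    "    # enlarge every coarse cell to a block x block patch, then crop to the target size\n",
    "    fine = np.repeat(np.repeat(coarse, block, axis=0), block, axis=1)\n",
    "    return fine[:height, :width]\n",
    "\n",
    "patch = pixelized_noise(40, 100)\n",
    "print(patch.shape, patch.dtype)\n",
    "```\n",
    "\n",
    "A larger `block` gives coarser noise, which plays the same role as larger `wfact` and `hfact` values in `distort_background`."
   ]
  },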
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "prime-update",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# display some autogenerated examples\n",
    "\n",
    "num_to_display = 20\n",
    "lp_size = (300, 75) # (width x height)\n",
    "\n",
    "for _ in range(num_to_display):\n",
    "    img, seq = gen_random_plate(final_shape=lp_size)\n",
    "    img = Image.fromarray(img)\n",
    "    print(seq)\n",
    "    display(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "periodic-province",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create images with multi_process\n",
    "import time\n",
    "import multiprocessing\n",
    "from tqdm import tqdm\n",
    "\n",
    "# autogenerated license plate image size (width x height)\n",
    "lp_size = (300, 75)\n",
    "\n",
    "def create_images(root_dir, i, k):\n",
    "    root_dir = root_dir.format(i)\n",
    "    if not os.path.exists(root_dir):\n",
    "        os.mkdir(root_dir)\n",
    "    for _ in tqdm(range(k)):\n",
    "        img, seq = gen_random_plate(final_shape=lp_size)\n",
    "        img = Image.fromarray(img)\n",
    "        img.save(root_dir + seq + '.png')\n",
    "    \n",
    "root_dir = \"./lp_autogenerate/train_batch{}/\"\n",
    "nproc = 16\n",
    "train_size = 4000\n",
    "\n",
    "start = time.perf_counter()\n",
    "processes = []\n",
    "for i in range(nproc):\n",
    "    p = multiprocessing.Process(target=create_images, args = [root_dir, i, int(train_size/nproc)])\n",
    "    p.start()\n",
    "    processes.append(p)\n",
    "for p in processes:\n",
    "    print(F\"{p}\")\n",
    "    p.join()\n",
    "end = time.perf_counter()\n",
    "\n",
    "print(f'Finished in {round(end-start, 2)} second(s)')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
